Skip to content

Convert common sub-functions as common sub-expressions #2788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

blegat
Copy link
Member

@blegat blegat commented Jul 22, 2025

Follow up from jump-dev/JuMP.jl#4032.
At the moment, you can have a small model in terms of memory footprint of the MOI.ScalarNonlinearFunction. However, when you add it to the MOI level, the AD tape can be exponentially bigger.
With this PR, the AD tape has the same memory footprint as the MOI.ScalarNonlinearFunction simply by interpreting the use of the same (in terms of the same pointer in memory, no expensive comparison is done) sub-functions as sub-expressions.

What's a little bit tricky to handle is that the first time to see an expression, you're just going to add it to the tape so the second time you see it, you need to replace it in the tape as a subexpression and update the indices of variables after the tape.
Since the sub-expression is contiguous in the tape, it's fortunately not too hard to do.

It also prevents the copy of MOI.ScalarNonlinearFunctions, what's the catch there ?

Closes jump-dev/JuMP.jl#4024

@blegat
Copy link
Member Author

blegat commented Jul 22, 2025

Consider the benchmark

using JuMP

f(x, u) = [sin(x[1]) - x[1] * u, cos(x[2]) + x[1] * u]
function RK4(f, X, u)
    k1 = f(X     , u)
    k2 = f(X+k1/2, u)
    k3 = f(X+k2/2, u)
    k4 = f(X+k3  , u)
    X + (k1 + 2k2 + 2k3 + k4) / 6
end

import Ipopt
function bench(n)
    model = direct_model(Ipopt.Optimizer())

    @variable(model, q[1:2])
    @variable(model, u)

    x = q
    for _ = 1:n
        x = RK4(f, x, u)
    end

    @constraint(model, x .== 0);

    @objective(model, Min, u^2)

    @time optimize!(model)
end
@time bench(4)

When combined with jump-dev/JuMP.jl#4032, I now get:

  0.002103 seconds (3.79 k allocations: 226.328 KiB)
  1.121932 seconds (9.58 k allocations: 513.109 KiB)

All the time is actually being spent by check_belongs_to_model in https://github.com/jump-dev/JuMP.jl/blob/b19c3e71e74e87bbf51d84cb1cfb94b0d8e42700/src/constraints.jl#L1037, if I comment out this line, I get:

  0.002041 seconds (3.79 k allocations: 226.328 KiB)
  0.002394 seconds (9.55 k allocations: 512.219 KiB)

This could be fixed by checking the model at the same time as we do moi_function.

Another caveat is that we need to use direct_model, otherwise, map_indices in

MOI.add_constraint(dest, map_indices(index_map, f), s)

ruins everything.

@blegat
Copy link
Member Author

blegat commented Jul 22, 2025

I fixed the performance issue of check_belongs_to_model and the issue of map_indices creating duplicates, it works now even without direct model

using JuMP

f(x, u) = [sin(x[1]) - x[1] * u, cos(x[2]) + x[1] * u]
function RK4(f, X, u)
    k1 = f(X     , u)
    k2 = f(X+k1/2, u)
    k3 = f(X+k2/2, u)
    k4 = f(X+k3  , u)
    X + (k1 + 2k2 + 2k3 + k4) / 6
end

import Ipopt
function bench(n)
    model = Model(Ipopt.Optimizer)

    @variable(model, q[1:2])
    @variable(model, u)

    x = q
    for _ = 1:n
        x = RK4(f, x, u)
    end

    @constraint(model, x .== 0);

    @objective(model, Min, u^2)

    @time optimize!(model)
end
bench(1)
bench(2)
bench(3)
@time bench(4)

I get

  0.001918 seconds (6.61 k allocations: 426.477 KiB)
  0.002599 seconds (11.97 k allocations: 614.852 KiB)

@odow odow marked this pull request as draft July 22, 2025 21:41
@odow
Copy link
Member

odow commented Jul 22, 2025

There's a cost to adding subexpressions. I need to think very carefully if we should do this. It's not an obvious win. I've been procrastinating on this, not because it is technically difficult, but because I'm not sure if it is something that MOI should even do. Users can manually extract their subexpressions if they desire.

I think at minimum, we're going to need a much larger set of benchmarks.

@blegat
Copy link
Member Author

blegat commented Jul 22, 2025

I agree, it's not a clear choice. What convinced me to write this PR is the following. When users share sub-expressions by reference, we either:

  1. Exponentially blow up the problem size (current)
  2. Handle it proportionally (this PR)

The exponential blowup can make problems intractable. We have workarounds but most users won't think about them, they will just have their computer freeze. On the other hand, having more sub-expressions than necessary can at worse be a bit slower. And we're not inventing sub-expressions, the user already had them in their code, either unintentionally or for saving memory.

Post-hoc detection (like what you did in https://github.com/lanl-ansi/MathOptSymbolicAD.jl) is complementary as it catches more sub-expressions but it doesn't prevent the blowup during construction.

parent_node,
)
data.cache[arg] = (expr, length(expr.nodes))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think we can tidy up all of this, but I'm not opposed to this part necessarily. We could be cleverer in how we do this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this version should be working and tested so now it could be refactored :)

else
parent.args[i] = MOI.Utilities.map_indices(index_map, arg)
end
end
if !isnothing(nl_cache)
nl_cache[f] = root
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part sucks. But we could work something out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean the fact that we should add a keyword argument to the function ?

@@ -103,6 +103,9 @@ function MOI.get(
) where {F,S}
MOI.throw_if_not_valid(v, ci)
f, _ = v.constraints[ci]::Tuple{F,S}
if f isa MOI.ScalarNonlinearFunction
return f
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is an issue that is quite hard to understand the implications of. We haven't exactly codified the contract of where functions get copied throughout the interface, because we assumed that it was cheap.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One think we could do is to call MA.copy_if_mutable. It should work since

julia> MOI.MA.mutability(MOI.ScalarNonlinearFunction)
MutableArithmetics.IsNotMutable()

julia> MOI.MA.mutability(MOI.ScalarAffineFunction{Float64})
MutableArithmetics.IsMutable()

@odow
Copy link
Member

odow commented Aug 5, 2025

I am happy if this works only in direct-mode. I need to think about this more. If we can build the tape fast, but it's slow to evaluate, then I'd consider writing a new AD backend that is structured as a complete DAG, rather than the current function-based approach.

@blegat
Copy link
Member Author

blegat commented Aug 6, 2025

Do we have a benchmark showing that the subexpressions are slow ? I think you shared one at some point but I don't remember if it showed that it was slower or just comparable so not worth it. We can definitely split this PR in smaller ones.

@blegat blegat marked this pull request as ready for review August 6, 2025 08:26
@blegat
Copy link
Member Author

blegat commented Aug 6, 2025

This part should be enough for direct mode, the other parts were moved to #2802 and #2803

@odow odow marked this pull request as draft August 7, 2025 03:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

Successfully merging this pull request may close these issues.

Performance problem with deeply nested nonlinear expressions
2 participants